jax-metal segmentation fault in lax.scan

Copying from https://github.com/google/jax/issues/20750:

import jax
import jax.numpy as jnp


def test_func(carry, x):
    # Scan body: pass the carry through unchanged and emit the current slice.
    return carry, x

def main():
    # Print available JAX devices
    print("JAX devices:", jax.devices())

    # Create two small constant matrices
    a = jnp.array([[1.0, 2.0], [3.0, 4.0]])
    b = jnp.array([[5.0, 6.0], [7.0, 8.0]])

    # Perform matrix multiplication
    c = jnp.dot(a, b)

    # Print the result
    print("Result of matrix multiplication:")
    print(c)

    # Compute the gradient of sum of c with respect to a
    grad_a = jax.grad(lambda a: jnp.sum(jnp.dot(a, b)))(a)
    print("Gradient with respect to a:")
    print(grad_a)

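    # Scan over the leading axis of a (5, 5, 5) random array with a scalar carry.
    # The segmentation fault occurs at this call on jax-metal 0.0.6.
    rng = jax.random.PRNGKey(0)
    test_input = jax.random.normal(key=rng, shape=(5, 5, 5))
    initial_state = jnp.array(0.0)

    final_carry, outputs = jax.lax.scan(test_func, initial_state, test_input)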

if __name__ == "__main__":
    main()

Output:

Platform 'METAL' is experimental and not all JAX functionality may be correctly supported!
2024-04-15 18:22:28.994752: W pjrt_plugin/src/mps_client.cc:563] WARNING: JAX Apple GPU support is experimental and not all JAX functionality is correctly supported!
Metal device set to: Apple M2 Pro

systemMemory: 16.00 GB
maxCacheSize: 5.33 GB

JAX devices: [METAL(id=0)]
Result of matrix multiplication:
[[19. 22.]
 [43. 50.]]
Gradient with respect to a:
[[11. 15.]
 [11. 15.]]
zsh: segmentation fault  python JAXTest.py

With more info from the debugger:

Current thread 0x00000001fdd3bac0 (most recent call first):
  File "/Users/.../anaconda3/lib/python3.11/site-packages/jax/_src/interpreters/pxla.py", line 1213 in __call__
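The trace above matches the format of Python's standard-library `faulthandler` module, which dumps the Python stack of every thread when the process receives a fatal signal such as SIGSEGV. That this is how the trace was produced is an assumption, since the report only says "debugger". Below is a minimal sketch of enabling it at the top of the repro script. The same effect is available without code changes via `python -X faulthandler JAXTest.py` or by setting `PYTHONFAULTHANDLER=1`.

```python
import faulthandler
import sys

# Dump the Python traceback of all threads to stderr if the process
# crashes with SIGSEGV, SIGFPE, SIGABRT, SIGBUS or SIGILL.
faulthandler.enable(file=sys.stderr, all_threads=True)

# ... the rest of the repro (import jax, main(), etc.) would follow here.
print("faulthandler enabled:", faulthandler.is_enabled())
```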

My configuration is:

jax-metal: 0.0.6
jax: 0.4.26
jaxlib: 0.4.23
numpy: 1.24.3
python: 3.11.8 | packaged by conda-forge | (main, Feb 16 2024, 20:49:36) [Clang 16.0.6 ]
jax.devices (1 total, 1 local): [METAL(id=0)]
process_count: 1
platform: uname_result(system='Darwin', root:xnu-10063.101.17~1/RELEASE_ARM64_T6020', machine='arm64')
macOS 14.4.1 (23E224)
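The JAX fields in the configuration appear to come from `jax.print_environment_info()`. When re-testing on newer releases, the package and OS versions can also be gathered with the standard library alone. The sketch below assumes this; `collect_versions` is a hypothetical helper, not part of JAX, and packages that are not installed are reported as such rather than raising.

```python
import platform
from importlib import metadata


def collect_versions(packages=("jax-metal", "jax", "jaxlib", "numpy")):
    """Return installed versions of the given distributions plus OS info."""
    info = {}
    for name in packages:
        try:
            info[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            info[name] = "not installed"
    info["python"] = platform.python_version()
    # mac_ver() returns an empty release string on non-macOS systems.
    info["macos"] = platform.mac_ver()[0] or "n/a"
    info["machine"] = platform.machine()
    return info


for key, value in collect_versions().items():
    print(f"{key}: {value}")
```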

This did not happen with earlier setups (e.g. Python 3.9 with jax-metal 0.0.3).

Update: the segmentation fault still occurs after updating to the second macOS Sonoma 14.5 beta.
